{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Document Question Answering"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "文書による質問応答は、文書による視覚的な質問応答とも呼ばれ、以下を提供するタスクです。\n",
    "ドキュメント画像に関する質問への回答。このタスクをサポートするモデルへの入力は通常、画像と画像の組み合わせです。\n",
    "質問があり、出力は自然言語で表現された回答です。これらのモデルは、以下を含む複数のモダリティを利用します。\n",
    "テキスト、単語の位置 (境界ボックス)、および画像自体。\n",
    "\n",
    "このガイドでは、次の方法を説明します。\n",
    "\n",
    "- [DocVQA データセット](https://huggingface.co/datasets/nielsr/docvqa_1200_examples_donut) の [LayoutLMv2](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/layoutlmv2) を微調整します。\n",
    "- 微調整されたモデルを推論に使用します。\n",
    "\n",
    "<Tip>\n",
    "\n",
    "このタスクと互換性のあるすべてのアーキテクチャとチェックポイントを確認するには、[タスクページ](https://huggingface.co/tasks/image-to-text) を確認することをお勧めします。\n",
    "\n",
    "</Tip>\n",
    "\n",
    "LayoutLMv2 は、最後の非表示のヘッダーの上に質問応答ヘッドを追加することで、ドキュメントの質問応答タスクを解決します。\n",
    "トークンの状態を調べて、トークンの開始トークンと終了トークンの位置を予測します。\n",
    "答え。言い換えれば、問題は抽出的質問応答として扱われます。つまり、コンテキストを考慮して、どの部分を抽出するかということです。\n",
    "の情報が質問に答えます。コンテキストは OCR エンジンの出力から取得されます。ここでは Google の Tesseract です。\n",
    "\n",
    "始める前に、必要なライブラリがすべてインストールされていることを確認してください。 LayoutLMv2 は detectron2、torchvision、tesseract に依存します。\n",
    "\n",
    "```bash\n",
    "pip install -q transformers datasets\n",
    "```\n",
    "\n",
    "```bash\n",
    "pip install 'git+https://github.com/facebookresearch/detectron2.git'\n",
    "pip install torchvision\n",
    "```\n",
    "\n",
    "```bash\n",
    "sudo apt install tesseract-ocr\n",
    "pip install -q pytesseract\n",
    "```\n",
    "\n",
    "すべての依存関係をインストールしたら、ランタイムを再起動します。\n",
    "\n",
    "モデルをコミュニティと共有することをお勧めします。 Hugging Face アカウントにログインして、🤗 ハブにアップロードします。\n",
    "プロンプトが表示されたら、トークンを入力してログインします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from huggingface_hub import notebook_login\n",
    "\n",
    "notebook_login()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "いくつかのグローバル変数を定義しましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_checkpoint = \"microsoft/layoutlmv2-base-uncased\"\n",
    "batch_size = 4"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "このガイドでは、🤗 Hub にある前処理された DocVQA の小さなサンプルを使用します。フルに使いたい場合は、\n",
    "DocVQA データセットは、[DocVQA ホームページ](https://rrc.cvc.uab.es/?ch=17) で登録してダウンロードできます。そうすれば、\n",
    "このガイドを進めて、[🤗 データセットにファイルをロードする方法](https://huggingface.co/docs/datasets/loading#local-and-remote-files) を確認してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DatasetDict({\n",
       "    train: Dataset({\n",
       "        features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],\n",
       "        num_rows: 1000\n",
       "    })\n",
       "    test: Dataset({\n",
       "        features: ['id', 'image', 'query', 'answers', 'words', 'bounding_boxes', 'answer'],\n",
       "        num_rows: 200\n",
       "    })\n",
       "})"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "dataset = load_dataset(\"nielsr/docvqa_1200_examples\")\n",
    "dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ご覧のとおり、データセットはすでにトレーニング セットとテスト セットに分割されています。理解するためにランダムな例を見てみましょう\n",
    "機能を備えた自分自身。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset[\"train\"].features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "個々のフィールドが表す内容は次のとおりです。\n",
    "* `id`: サンプルのID\n",
    "* `image`: ドキュメント画像を含む PIL.Image.Image オブジェクト\n",
    "* `query`: 質問文字列 - いくつかの言語での自然言語による質問\n",
    "* `answers`: ヒューマン アノテーターによって提供された正解のリスト\n",
    "* `words` と `bounding_boxes`: OCR の結果。ここでは使用しません。\n",
    "* `answer`: 別のモデルと一致する答え。ここでは使用しません。\n",
    "\n",
    "英語の質問だけを残し、別のモデルによる予測が含まれていると思われる`answer`機能を削除しましょう。\n",
    "また、アノテーターによって提供されたセットから最初の回答を取得します。あるいは、ランダムにサンプリングすることもできます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "updated_dataset = dataset.map(lambda example: {\"question\": example[\"query\"][\"en\"]}, remove_columns=[\"query\"])\n",
    "updated_dataset = updated_dataset.map(\n",
    "    lambda example: {\"answer\": example[\"answers\"][0]}, remove_columns=[\"answer\", \"answers\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "このガイドで使用する LayoutLMv2 チェックポイントは、`max_position_embeddings = 512` でトレーニングされていることに注意してください (\n",
    "この情報は、[チェックポイントの `config.json` ファイル](https://huggingface.co/microsoft/layoutlmv2-base-uncased/blob/main/config.json#L18)) で見つけてください。\n",
    "例を省略することもできますが、答えが大きな文書の最後にあり、結局省略されてしまうという状況を避けるために、\n",
    "ここでは、埋め込みが 512 を超える可能性があるいくつかの例を削除します。\n",
    "データセット内のほとんどのドキュメントが長い場合は、スライディング ウィンドウ戦略を実装できます。詳細については、[このノートブック](https://github.com/huggingface/notebooks/blob/main/examples/question_answering.ipynb) を確認してください。 。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "updated_dataset = updated_dataset.filter(lambda x: len(x[\"words\"]) + len(x[\"question\"].split()) < 512)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この時点で、このデータセットから OCR 機能も削除しましょう。これらは、異なるデータを微調整するための OCR の結果です。\n",
    "モデル。これらは入力要件と一致しないため、使用したい場合はさらに処理が必要になります。\n",
    "このガイドで使用するモデルの。代わりに、OCR と OCR の両方の元のデータに対して `LayoutLMv2Processor` を使用できます。\n",
    "トークン化。このようにして、モデルの予想される入力と一致する入力を取得します。画像を手動で加工したい場合は、\n",
    "モデルがどのような入力形式を想定しているかを知るには、[`LayoutLMv2` モデルのドキュメント](https://huggingface.co/docs/transformers/main/ja/tasks/../model_doc/layoutlmv2) を確認してください。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "updated_dataset = updated_dataset.remove_columns(\"words\")\n",
    "updated_dataset = updated_dataset.remove_columns(\"bounding_boxes\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最後に、画像サンプルを確認しないとデータ探索は完了しません。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "updated_dataset[\"train\"][11][\"image\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<div class=\"flex justify-center\">\n",
    "     <img src=\"https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/docvqa_example.jpg\" alt=\"DocVQA Image Example\"/>\n",
    " </div>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preprocess the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "文書の質問に答えるタスクはマルチモーダル タスクであるため、各モダリティからの入力が確実に行われるようにする必要があります。\n",
    "モデルの期待に従って前処理されます。まず、`LayoutLMv2Processor` をロードします。これは、画像データを処理できる画像プロセッサとテキスト データをエンコードできるトークナイザーを内部で組み合わせています。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoProcessor\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing document images"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "まず、プロセッサからの `image_processor` を利用して、モデルのドキュメント画像を準備しましょう。\n",
    "デフォルトでは、画像プロセッサは画像のサイズを 224x224 に変更し、カラー チャネルの順序が正しいことを確認します。\n",
    "tesseract を使用して OCR を適用し、単語と正規化された境界ボックスを取得します。このチュートリアルでは、これらのデフォルトはすべて、まさに必要なものです。\n",
    "デフォルトの画像処理を画像のバッチに適用し、OCR の結果を返す関数を作成します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_processor = processor.image_processor\n",
    "\n",
    "\n",
    "def get_ocr_words_and_boxes(examples):\n",
    "    images = [image.convert(\"RGB\") for image in examples[\"image\"]]\n",
    "    encoded_inputs = image_processor(images)\n",
    "\n",
    "    examples[\"image\"] = encoded_inputs.pixel_values\n",
    "    examples[\"words\"] = encoded_inputs.words\n",
    "    examples[\"boxes\"] = encoded_inputs.boxes\n",
    "\n",
    "    return examples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この前処理をデータセット全体に高速に適用するには、`map` を使用します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataset_with_ocr = updated_dataset.map(get_ocr_words_and_boxes, batched=True, batch_size=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Preprocessing text data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "画像に OCR を適用したら、データセットのテキスト部分をエンコードしてモデル用に準備する必要があります。\n",
    "これには、前のステップで取得した単語とボックスをトークンレベルの `input_ids`、`attention_mask`、\n",
    "`token_type_ids`と`bbox`。テキストを前処理するには、プロセッサからの`Tokenizer`が必要になります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "tokenizer = processor.tokenizer"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "前述の前処理に加えて、モデルのラベルを追加する必要もあります。 `xxxForQuestionAnswering` モデルの場合\n",
    "🤗 Transformers では、ラベルは `start_positions` と `end_positions` で構成され、どのトークンがその位置にあるかを示します。\n",
    "開始点と、どのトークンが回答の最後にあるか。\n",
    "\n",
    "それから始めましょう。より大きなリスト (単語リスト) 内のサブリスト (単語に分割された回答) を検索できるヘルパー関数を定義します。\n",
    "\n",
    "この関数は、`words_list` と `answer_list` という 2 つのリストを入力として受け取ります。次に、`words_list`を反復処理してチェックします。\n",
    "`words_list` (words_list[i]) 内の現在の単語が、answer_list (answer_list[0]) の最初の単語と等しいかどうか、および\n",
    "現在の単語から始まり、`answer_list` と同じ長さの `words_list` のサブリストは、`to answer_list` と等しくなります。\n",
    "この条件が true の場合、一致が見つかったことを意味し、関数は一致とその開始インデックス (idx) を記録します。\n",
    "とその終了インデックス (idx + len(answer_list) - 1)。複数の一致が見つかった場合、関数は最初のもののみを返します。\n",
    "一致するものが見つからない場合、関数は (`None`、0、および 0) を返します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def subfinder(words_list, answer_list):\n",
    "    matches = []\n",
    "    start_indices = []\n",
    "    end_indices = []\n",
    "    for idx, i in enumerate(range(len(words_list))):\n",
    "        if words_list[i] == answer_list[0] and words_list[i : i + len(answer_list)] == answer_list:\n",
    "            matches.append(answer_list)\n",
    "            start_indices.append(idx)\n",
    "            end_indices.append(idx + len(answer_list) - 1)\n",
    "    if matches:\n",
    "        return matches[0], start_indices[0], end_indices[0]\n",
    "    else:\n",
    "        return None, 0, 0"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この関数が答えの位置を見つける方法を説明するために、例で使用してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Question:  Who is in  cc in this letter?\n",
       "Words: ['wie', 'baw', 'brown', '&', 'williamson', 'tobacco', 'corporation', 'research', '&', 'development', 'internal', 'correspondence', 'to:', 'r.', 'h.', 'honeycutt', 'ce:', 't.f.', 'riehl', 'from:', '.', 'c.j.', 'cook', 'date:', 'may', '8,', '1995', 'subject:', 'review', 'of', 'existing', 'brainstorming', 'ideas/483', 'the', 'major', 'function', 'of', 'the', 'product', 'innovation', 'graup', 'is', 'to', 'develop', 'marketable', 'nove!', 'products', 'that', 'would', 'be', 'profitable', 'to', 'manufacture', 'and', 'sell.', 'novel', 'is', 'defined', 'as:', 'of', 'a', 'new', 'kind,', 'or', 'different', 'from', 'anything', 'seen', 'or', 'known', 'before.', 'innovation', 'is', 'defined', 'as:', 'something', 'new', 'or', 'different', 'introduced;', 'act', 'of', 'innovating;', 'introduction', 'of', 'new', 'things', 'or', 'methods.', 'the', 'products', 'may', 'incorporate', 'the', 'latest', 'technologies,', 'materials', 'and', 'know-how', 'available', 'to', 'give', 'then', 'a', 'unique', 'taste', 'or', 'look.', 'the', 'first', 'task', 'of', 'the', 'product', 'innovation', 'group', 'was', 'to', 'assemble,', 'review', 'and', 'categorize', 'a', 'list', 'of', 'existing', 'brainstorming', 'ideas.', 'ideas', 'were', 'grouped', 'into', 'two', 'major', 'categories', 'labeled', 'appearance', 'and', 'taste/aroma.', 'these', 'categories', 'are', 'used', 'for', 'novel', 'products', 'that', 'may', 'differ', 'from', 'a', 'visual', 'and/or', 'taste/aroma', 'point', 'of', 'view', 'compared', 'to', 'canventional', 'cigarettes.', 'other', 'categories', 'include', 'a', 'combination', 'of', 'the', 'above,', 'filters,', 'packaging', 'and', 'brand', 'extensions.', 'appearance', 'this', 'category', 'is', 'used', 'for', 'novel', 'cigarette', 'constructions', 'that', 'yield', 'visually', 'different', 'products', 'with', 'minimal', 'changes', 'in', 'smoke', 'chemistry', 'two', 'cigarettes', 'in', 'cne.', 'emulti-plug', 'te', 'build', 'yaur', 'awn', 'cigarette.', 'eswitchable', 'menthol', 'or', 'non', 'menthol', 'cigarette.', '*cigarettes', 'with', 'interspaced', 'perforations', 'to', 'enable', 'smoker', 'to', 'separate', 'unburned', 'section', 'for', 'future', 'smoking.', '«short', 'cigarette,', 'tobacco', 'section', '30', 'mm.', '«extremely', 'fast', 'buming', 'cigarette.', '«novel', 'cigarette', 'constructions', 'that', 'permit', 'a', 'significant', 'reduction', 'iretobacco', 'weight', 'while', 'maintaining', 'smoking', 'mechanics', 'and', 'visual', 'characteristics.', 'higher', 'basis', 'weight', 'paper:', 'potential', 'reduction', 'in', 'tobacco', 'weight.', '«more', 'rigid', 'tobacco', 'column;', 'stiffing', 'agent', 'for', 'tobacco;', 'e.g.', 'starch', '*colored', 'tow', 'and', 'cigarette', 'papers;', 'seasonal', 'promotions,', 'e.g.', 'pastel', 'colored', 'cigarettes', 'for', 'easter', 'or', 'in', 'an', 'ebony', 'and', 'ivory', 'brand', 'containing', 'a', 'mixture', 'of', 'all', 'black', '(black', 'paper', 'and', 'tow)', 'and', 'ail', 'white', 'cigarettes.', '499150498']\n",
       "Answer:  T.F. Riehl\n",
       "start_index 17\n",
       "end_index 18"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example = dataset_with_ocr[\"train\"][1]\n",
    "words = [word.lower() for word in example[\"words\"]]\n",
    "match, word_idx_start, word_idx_end = subfinder(words, example[\"answer\"].lower().split())\n",
    "print(\"Question: \", example[\"question\"])\n",
    "print(\"Words:\", words)\n",
    "print(\"Answer: \", example[\"answer\"])\n",
    "print(\"start_index\", word_idx_start)\n",
    "print(\"end_index\", word_idx_end)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "ただし、サンプルがエンコードされると、次のようになります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[CLS] who is in cc in this letter? [SEP] wie baw brown & williamson tobacco corporation research & development ..."
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoding = tokenizer(example[\"question\"], example[\"words\"], example[\"boxes\"])\n",
    "tokenizer.decode(encoding[\"input_ids\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "エンコードされた入力内で答えの位置を見つける必要があります。\n",
    "* `token_type_ids` は、どのトークンが質問の一部であり、どのトークンが文書の単語の一部であるかを示します。\n",
    "* `tokenizer.cls_token_id` は、入力の先頭で特別なトークンを見つけるのに役立ちます。\n",
    "* `word_ids` は、元の `words` で見つかった回答を、完全にエンコードされた入力内の同じ回答と照合して判断するのに役立ちます。\n",
    "エンコードされた入力内の応答の開始/終了位置。\n",
    "\n",
    "これを念頭に置いて、データセット内のサンプルのバッチをエンコードする関数を作成しましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode_dataset(examples, max_length=512):\n",
    "    questions = examples[\"question\"]\n",
    "    words = examples[\"words\"]\n",
    "    boxes = examples[\"boxes\"]\n",
    "    answers = examples[\"answer\"]\n",
    "\n",
    "    # encode the batch of examples and initialize the start_positions and end_positions\n",
    "    encoding = tokenizer(questions, words, boxes, max_length=max_length, padding=\"max_length\", truncation=True)\n",
    "    start_positions = []\n",
    "    end_positions = []\n",
    "\n",
    "    # loop through the examples in the batch\n",
    "    for i in range(len(questions)):\n",
    "        cls_index = encoding[\"input_ids\"][i].index(tokenizer.cls_token_id)\n",
    "\n",
    "        # find the position of the answer in example's words\n",
    "        words_example = [word.lower() for word in words[i]]\n",
    "        answer = answers[i]\n",
    "        match, word_idx_start, word_idx_end = subfinder(words_example, answer.lower().split())\n",
    "\n",
    "        if match:\n",
    "            # if match is found, use `token_type_ids` to find where words start in the encoding\n",
    "            token_type_ids = encoding[\"token_type_ids\"][i]\n",
    "            token_start_index = 0\n",
    "            while token_type_ids[token_start_index] != 1:\n",
    "                token_start_index += 1\n",
    "\n",
    "            token_end_index = len(encoding[\"input_ids\"][i]) - 1\n",
    "            while token_type_ids[token_end_index] != 1:\n",
    "                token_end_index -= 1\n",
    "\n",
    "            word_ids = encoding.word_ids(i)[token_start_index : token_end_index + 1]\n",
    "            start_position = cls_index\n",
    "            end_position = cls_index\n",
    "\n",
    "            # loop over word_ids and increase `token_start_index` until it matches the answer position in words\n",
    "            # once it matches, save the `token_start_index` as the `start_position` of the answer in the encoding\n",
    "            for id in word_ids:\n",
    "                if id == word_idx_start:\n",
    "                    start_position = token_start_index\n",
    "                else:\n",
    "                    token_start_index += 1\n",
    "\n",
    "            # similarly loop over `word_ids` starting from the end to find the `end_position` of the answer\n",
    "            for id in word_ids[::-1]:\n",
    "                if id == word_idx_end:\n",
    "                    end_position = token_end_index\n",
    "                else:\n",
    "                    token_end_index -= 1\n",
    "\n",
    "            start_positions.append(start_position)\n",
    "            end_positions.append(end_position)\n",
    "\n",
    "        else:\n",
    "            start_positions.append(cls_index)\n",
    "            end_positions.append(cls_index)\n",
    "\n",
    "    encoding[\"image\"] = examples[\"image\"]\n",
    "    encoding[\"start_positions\"] = start_positions\n",
    "    encoding[\"end_positions\"] = end_positions\n",
    "\n",
    "    return encoding"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "この前処理関数が完成したので、データセット全体をエンコードできます。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "encoded_train_dataset = dataset_with_ocr[\"train\"].map(\n",
    "    encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr[\"train\"].column_names\n",
    ")\n",
    "encoded_test_dataset = dataset_with_ocr[\"test\"].map(\n",
    "    encode_dataset, batched=True, batch_size=2, remove_columns=dataset_with_ocr[\"test\"].column_names\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "エンコードされたデータセットの特徴がどのようなものかを確認してみましょう。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'image': Sequence(feature=Sequence(feature=Sequence(feature=Value(dtype='uint8', id=None), length=-1, id=None), length=-1, id=None), length=-1, id=None),\n",
       " 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),\n",
       " 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),\n",
       " 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),\n",
       " 'bbox': Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None),\n",
       " 'start_positions': Value(dtype='int64', id=None),\n",
       " 'end_positions': Value(dtype='int64', id=None)}"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "encoded_train_dataset.features"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "文書の質問回答の評価には、大量の後処理が必要です。過剰摂取を避けるために\n",
    "現時点では、このガイドでは評価ステップを省略しています。 [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) はトレーニング中に評価損失を計算するため、\n",
    "モデルのパフォーマンスについてまったくわからないわけではありません。抽出的質問応答は通常、F1/完全一致を使用して評価されます。\n",
    "自分で実装したい場合は、[質問応答の章](https://huggingface.co/course/chapter7/7?fw=pt#postprocessing) を確認してください。\n",
    "インスピレーションを得るためにハグフェイスコースの。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "おめでとう！このガイドの最も難しい部分を無事にナビゲートできたので、独自のモデルをトレーニングする準備が整いました。\n",
    "トレーニングには次の手順が含まれます。\n",
    "* 前処理と同じチェックポイントを使用して、[AutoModelForDocumentQuestionAnswering](https://huggingface.co/docs/transformers/main/ja/model_doc/auto#transformers.AutoModelForDocumentQuestionAnswering) でモデルを読み込みます。\n",
    "* [TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) でトレーニング ハイパーパラメータを定義します。\n",
    "* サンプルをバッチ処理する関数を定義します。ここでは `DefaultDataCollat​​or` が適切に機能します。\n",
    "* モデル、データセット、データ照合器とともにトレーニング引数を [Trainer](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer) に渡します。\n",
    "* [train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出してモデルを微調整します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoModelForDocumentQuestionAnswering\n",
    "\n",
    "model = AutoModelForDocumentQuestionAnswering.from_pretrained(model_checkpoint)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[TrainingArguments](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.TrainingArguments) で、`output_dir` を使用してモデルの保存場所を指定し、必要に応じてハイパーパラメーターを構成します。\n",
    "モデルをコミュニティと共有したい場合は、`push_to_hub`を`True`に設定します (モデルをアップロードするには、Hugging Face にサインインする必要があります)。\n",
    "この場合、`output_dir`はモデルのチェックポイントがプッシュされるリポジトリの名前にもなります。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import TrainingArguments\n",
    "\n",
    "# REPLACE THIS WITH YOUR REPO ID\n",
    "repo_id = \"MariaK/layoutlmv2-base-uncased_finetuned_docvqa\"\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir=repo_id,\n",
    "    per_device_train_batch_size=4,\n",
    "    num_train_epochs=20,\n",
    "    save_steps=200,\n",
    "    logging_steps=50,\n",
    "    eval_strategy=\"steps\",\n",
    "    learning_rate=5e-5,\n",
    "    save_total_limit=2,\n",
    "    remove_unused_columns=False,\n",
    "    push_to_hub=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "サンプルをまとめてバッチ処理するための単純なデータ照合器を定義します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import DefaultDataCollator\n",
    "\n",
    "data_collator = DefaultDataCollator()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最後に、すべてをまとめて、[train()](https://huggingface.co/docs/transformers/main/ja/main_classes/trainer#transformers.Trainer.train) を呼び出します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import Trainer\n",
    "\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    args=training_args,\n",
    "    data_collator=data_collator,\n",
    "    train_dataset=encoded_train_dataset,\n",
    "    eval_dataset=encoded_test_dataset,\n",
    "    processing_class=processor,\n",
    ")\n",
    "\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "最終モデルを 🤗 Hub に追加するには、モデル カードを作成し、`push_to_hub` を呼び出します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "trainer.create_model_card()\n",
    "trainer.push_to_hub()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Inference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "LayoutLMv2 モデルを微調整し、🤗 ハブにアップロードしたので、それを推論に使用できます。もっとも単純な\n",
    "推論用に微調整されたモデルを試す方法は、それを [Pipeline](https://huggingface.co/docs/transformers/main/ja/main_classes/pipelines#transformers.Pipeline) で使用することです。\n",
    "\n",
    "例を挙げてみましょう:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Who is ‘presiding’ TRRF GENERAL SESSION (PART 1)?'\n",
       "['TRRF Vice President', 'lee a. waller']"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "example = dataset[\"test\"][2]\n",
    "question = example[\"query\"][\"en\"]\n",
    "image = example[\"image\"]\n",
    "print(question)\n",
    "print(example[\"answers\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "次に、パイプラインをインスタンス化します。\n",
    "モデルを使用して質問への回答を文書化し、画像と質問の組み合わせをモデルに渡します。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.9949808120727539,\n",
       "  'answer': 'Lee A. Waller',\n",
       "  'start': 55,\n",
       "  'end': 57}]"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "qa_pipeline = pipeline(\"document-question-answering\", model=\"MariaK/layoutlmv2-base-uncased_finetuned_docvqa\")\n",
    "qa_pipeline(image, question)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "必要に応じて、パイプラインの結果を手動で複製することもできます。\n",
    "1. 画像と質問を取得し、モデルのプロセッサを使用してモデル用に準備します。\n",
    "2. モデルを通じて結果または前処理を転送します。\n",
    "3. モデルは`start_logits`と`end_logits`を返します。これらは、どのトークンが応答の先頭にあるのかを示し、\n",
    "どのトークンが回答の最後にありますか。どちらも形状 (batch_size、sequence_length) を持ちます。\n",
    "4. `start_logits` と `end_logits` の両方の最後の次元で argmax を取得し、予測される `start_idx` と `end_idx` を取得します。\n",
    "5. トークナイザーを使用して回答をデコードします。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'lee a. waller'"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from transformers import AutoProcessor\n",
    "from transformers import AutoModelForDocumentQuestionAnswering\n",
    "\n",
    "processor = AutoProcessor.from_pretrained(\"MariaK/layoutlmv2-base-uncased_finetuned_docvqa\")\n",
    "model = AutoModelForDocumentQuestionAnswering.from_pretrained(\"MariaK/layoutlmv2-base-uncased_finetuned_docvqa\")\n",
    "\n",
    "with torch.no_grad():\n",
    "    encoding = processor(image.convert(\"RGB\"), question, return_tensors=\"pt\")\n",
    "    outputs = model(**encoding)\n",
    "    start_logits = outputs.start_logits\n",
    "    end_logits = outputs.end_logits\n",
    "    predicted_start_idx = start_logits.argmax(-1).item()\n",
    "    predicted_end_idx = end_logits.argmax(-1).item()\n",
    "\n",
    "processor.tokenizer.decode(encoding.input_ids.squeeze()[predicted_start_idx : predicted_end_idx + 1])"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 4
}
